{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2ca35392",
   "metadata": {},
   "source": [
    "# Targeted Quantization with Quanto"
   ]
  },
  {
   "cell_type": "raw",
   "id": "68381174",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/PrunaAI/pruna/blob/v|version|/docs/tutorials/target_modules_quanto.ipynb\">\n",
    "    <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
    "</a>"
   ]
  },
  {
   "cell_type": "raw",
   "id": "6be5627c",
   "metadata": {
    "raw_mimetype": "text/restructuredtext",
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "This tutorial demonstrates how to use :doc:`Target Modules </docs_pruna/user_manual/power_user>` hyperparameters to specify which modules an algorithm should be applied to. We will use Quanto for this demonstration, as it exposes such a ``target_modules`` hyperparameter."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5755afb2",
   "metadata": {},
   "source": [
    "### Getting Started"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7834d46",
   "metadata": {},
   "outputs": [],
   "source": [
    "# if you are not running the latest version of this tutorial, make sure to install the matching version of pruna\n",
    "# the following command will install the latest version of pruna\n",
    "%pip install pruna"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dad6c69",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4445e12e",
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "source": [
    "### 1. Loading a Model"
   ]
  },
  {
   "cell_type": "raw",
   "id": "2fc8eed0",
   "metadata": {
    "raw_mimetype": "text/restructuredtext",
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "We will demonstrate this feature using Flux.\n",
    "Alternatively, you could use an LLM instead of an image generation model and adapt the loading, configuration, and evaluation steps by following :doc:`this tutorial </docs_pruna/tutorials/llm_quantization_compilation_acceleration>`, for example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73d954b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from diffusers import DiffusionPipeline\n",
    "\n",
    "model_id = \"black-forest-labs/FLUX.1-dev\"\n",
    "pipe = DiffusionPipeline.from_pretrained(\n",
    "    pretrained_model_name_or_path=model_id,\n",
    "    torch_dtype=torch.bfloat16,\n",
    ")\n",
    "pipe = pipe.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b2ff44e",
   "metadata": {},
   "source": [
    "We'll generate an image now so that we can compare it with the quantized model later."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ca7e4b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"A crow flying over snowy mountains with a vibrant green valley below and warm colors from the sunset. High resolution, realistic style.\"\n",
    "img = pipe(prompt, generator=torch.Generator(device=device).manual_seed(42)).images[0]\n",
    "img.save(\"original.png\")\n",
    "img"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18b4c694",
   "metadata": {},
   "source": [
    "### 2. Define a SmashConfig with Quanto and Smash the Model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9de62754",
   "metadata": {},
   "source": [
    "A module is selected if it matches at least one include pattern and none of the exclude patterns."
   ]
  },
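  {
   "cell_type": "markdown",
   "id": "a1f2b3c4",
   "metadata": {},
   "source": [
    "For intuition, here is a minimal sketch of that matching rule using Python's `fnmatch`, which implements the same Unix shell-style patterns. The module paths below are hypothetical and only illustrate the rule, not Pruna's internal implementation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5c6d7e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from fnmatch import fnmatch\n",
    "\n",
    "include = [\"transformer.*\"]\n",
    "exclude = [\"*embed*\"]\n",
    "\n",
    "\n",
    "def is_selected(name: str) -> bool:\n",
    "    # a module is selected if its path matches at least one include pattern\n",
    "    # and none of the exclude patterns\n",
    "    matches_include = any(fnmatch(name, pattern) for pattern in include)\n",
    "    matches_exclude = any(fnmatch(name, pattern) for pattern in exclude)\n",
    "    return matches_include and not matches_exclude\n",
    "\n",
    "\n",
    "# hypothetical module paths for illustration\n",
    "print(is_selected(\"transformer.blocks.0.attn\"))  # True: included, not excluded\n",
    "print(is_selected(\"transformer.time_embed\"))  # False: excluded by *embed*\n",
    "print(is_selected(\"vae.decoder.conv_in\"))  # False: matches no include pattern"
   ]
  },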
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "504d0d92",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pruna import SmashConfig, smash\n",
    "\n",
    "smash_config = SmashConfig({\"quanto\": {\"weight_bits\": \"qint4\"}})\n",
    "pipe = smash(pipe, smash_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6854d4ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "img = pipe(prompt, generator=torch.Generator(device=device).manual_seed(42)).images[0]\n",
    "img.save(\"smashed.png\")\n",
    "img"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1af03b63",
   "metadata": {},
   "source": [
    "We're using 4-bit quantization, which is quite aggressive for Quanto. The quality is still good, but the image differs from the original.\n",
    "To bring it closer, we'll exclude quantization-sensitive parts of the model from the quantization."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e110d97f",
   "metadata": {},
   "source": [
    "### 3. Load the Original Model again"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae629097",
   "metadata": {},
   "source": [
    "Let's first delete the smashed model to free memory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee82542d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pruna.engine.utils import safe_memory_cleanup\n",
    "\n",
    "pipe.destroy()\n",
    "del pipe\n",
    "safe_memory_cleanup()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06962b1e",
   "metadata": {},
   "source": [
    "Now we can load the original model again."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23d806a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "pipe = DiffusionPipeline.from_pretrained(\n",
    "    pretrained_model_name_or_path=model_id,\n",
    "    torch_dtype=torch.bfloat16,\n",
    ")\n",
    "pipe = pipe.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f868185a",
   "metadata": {},
   "source": [
    "### 4. Smash the Model again with Target Modules"
   ]
  },
  {
   "cell_type": "raw",
   "id": "14ffad65",
   "metadata": {
    "raw_mimetype": "text/restructuredtext",
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "Some algorithms, such as Quanto, offer the option of specifying a :doc:`target_modules </docs_pruna/user_manual/power_user>` hyperparameter.\n",
    "This allows you to choose which modules the algorithm is applied to by providing Unix shell-style patterns for modules to include or exclude."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "081bc8ef",
   "metadata": {},
   "source": [
    "As before, we use 4-bit quantization from Quanto:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31ced6bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "smash_config = SmashConfig({\"quanto\": {\"weight_bits\": \"qint4\"}})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2974373f",
   "metadata": {},
   "source": [
    "Here, we'll avoid quantizing the embedding-related layers, which are sensitive to quantization.\n",
    "\n",
    "Note that we are applying quantization to fewer parts of the model, so the smashed model will need more VRAM than the fully quantized version.\n",
    "However, the selection below excludes only 0.4% of the parameters, so the overhead should be manageable."
   ]
  },
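  {
   "cell_type": "markdown",
   "id": "c3d4e5f6",
   "metadata": {},
   "source": [
    "To sanity-check a claim like that, you can apply the same patterns to the parameter names and sum the counts. The sketch below uses made-up (name, parameter count) pairs for illustration; on the real pipeline you would collect them from `pipe.transformer.named_parameters()` instead:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7e8f9a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from fnmatch import fnmatch\n",
    "\n",
    "include = [\"transformer.*\"]\n",
    "exclude = [\"*embed*\"]\n",
    "\n",
    "# made-up (parameter path, parameter count) pairs for illustration only\n",
    "param_counts = [\n",
    "    (\"transformer.blocks.0.attn.to_q.weight\", 9_437_184),\n",
    "    (\"transformer.blocks.0.ff.proj.weight\", 37_748_736),\n",
    "    (\"transformer.time_embed.linear.weight\", 786_432),\n",
    "]\n",
    "\n",
    "\n",
    "def is_quantized(name: str) -> bool:\n",
    "    # same rule as target_modules: at least one include match, no exclude match\n",
    "    return any(fnmatch(name, p) for p in include) and not any(\n",
    "        fnmatch(name, p) for p in exclude\n",
    "    )\n",
    "\n",
    "\n",
    "total = sum(count for _, count in param_counts)\n",
    "excluded = sum(count for name, count in param_counts if not is_quantized(name))\n",
    "print(f\"excluded fraction: {excluded / total:.1%}\")"
   ]
  },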
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d7ef30a",
   "metadata": {},
   "outputs": [],
   "source": [
    "smash_config.add({\"quanto_target_modules\": {\n",
    "    \"include\": [\"transformer.*\"],  # here we consider any module in the model's transformer\n",
    "    \"exclude\": [\"*embed*\"],  # and exclude any module containing \"embed\" in its path\n",
    "    # you can add other patterns in the include or exclude lists\n",
    "}})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be049e96",
   "metadata": {},
   "outputs": [],
   "source": [
    "pipe = smash(pipe, smash_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c792e19b",
   "metadata": {},
   "outputs": [],
   "source": [
    "img = pipe(prompt, generator=torch.Generator(device=device).manual_seed(42)).images[0]\n",
    "img.save(\"smashed_target_modules.png\")\n",
    "img"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb24a47a",
   "metadata": {},
   "source": [
    "### 5. Compare the Results"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70245373",
   "metadata": {},
   "source": [
    "Let's load the images we generated and show them side by side."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4fef931",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.figure(figsize=(15, 5))\n",
    "for i, img_name in enumerate([\"original.png\", \"smashed.png\", \"smashed_target_modules.png\"]):\n",
    "    plt.subplot(1, 3, i + 1)\n",
    "    plt.imshow(plt.imread(img_name))\n",
    "    plt.title(img_name)\n",
    "    plt.axis(\"off\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "157a3b44",
   "metadata": {},
   "source": [
    "Much better! Although the fully-quantized model produced an image close to the original, the one with the `target_modules` option is even closer."
   ]
  },
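  {
   "cell_type": "markdown",
   "id": "e1f2a3b4",
   "metadata": {},
   "source": [
    "If you want to go beyond eyeballing, a simple pixel-wise metric such as the mean squared error can quantify how far each image is from the original. Here is a minimal, self-contained sketch; the tiny arrays stand in for real images, and for an actual comparison you could load the saved PNGs with `PIL.Image.open` and pass them to the same function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5a6b7c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def mse(image_a, image_b):\n",
    "    # mean squared error between two equally shaped image arrays; lower = closer\n",
    "    a = np.asarray(image_a, dtype=np.float64)\n",
    "    b = np.asarray(image_b, dtype=np.float64)\n",
    "    return ((a - b) ** 2).mean()\n",
    "\n",
    "\n",
    "# tiny stand-in arrays; for the tutorial images, compare original.png\n",
    "# against smashed.png and smashed_target_modules.png the same way\n",
    "original = np.array([[0, 128], [255, 64]])\n",
    "candidate = np.array([[2, 126], [250, 70]])\n",
    "print(mse(original, original))  # 0.0: identical images\n",
    "print(mse(original, candidate))  # 17.25"
   ]
  },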
  {
   "cell_type": "markdown",
   "id": "6cf329c5",
   "metadata": {},
   "source": [
    "### Wrap up"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd744e68",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "Congratulations! You have successfully smashed a model with fine-grained control over which modules are quantized.\n",
    "\n",
    "You can experiment with excluding different parts of the model and observe the effect on the output quality, the VRAM usage, and the speed of the model.\n",
    "We plan to add this option to other quantizers, so make sure to check whether your favorite algorithm already has a `target_modules` hyperparameter!\n"
   ]
  },
  {
   "cell_type": "raw",
   "id": "5f2f7b93",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "Make sure to check out the other :doc:`pruna tutorials </docs_pruna/tutorials/index>` to squeeze all the speed out of your model!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Pruna",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
